{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "## 4.1.2 Embedding层的实现"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2023-05-06T12:05:06.841341100Z",
     "start_time": "2023-05-06T12:05:06.684667900Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [],
   "source": [
    "W = np.arange(21).reshape(7, 3)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-06T12:05:20.168362900Z",
     "start_time": "2023-05-06T12:05:20.157361500Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [
    {
     "data": {
      "text/plain": "array([[ 0,  1,  2],\n       [ 3,  4,  5],\n       [ 6,  7,  8],\n       [ 9, 10, 11],\n       [12, 13, 14],\n       [15, 16, 17],\n       [18, 19, 20]])"
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "W"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-06T12:05:21.630605600Z",
     "start_time": "2023-05-06T12:05:21.615606400Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [
    {
     "data": {
      "text/plain": "array([6, 7, 8])"
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "W[2]"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-06T12:32:20.345637500Z",
     "start_time": "2023-05-06T12:32:20.341638600Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "outputs": [
    {
     "data": {
      "text/plain": "array([[ 3,  4,  5],\n       [ 0,  1,  2],\n       [ 9, 10, 11],\n       [ 0,  1,  2]])"
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "idx = np.array([1, 0, 3, 0])\n",
    "W[idx]"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-06T12:33:13.850460700Z",
     "start_time": "2023-05-06T12:33:13.822464200Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "outputs": [],
   "source": [
    "class Embedding:\n",
    "    def __init__(self, W):\n",
    "        self.params = [W]\n",
    "        self.grads = [np.zeros_like(W)]\n",
    "        self.idx = None\n",
    "\n",
    "    def forward(self, idx):\n",
    "        W, = self.params\n",
    "        self.idx = idx\n",
    "        out = W[idx]\n",
    "        return out\n",
    "\n",
    "    def backward(self, dout):\n",
    "        dW, = self.grads\n",
    "        dW[...] = 0\n",
    "\n",
    "        for i, word_id in enumerate(self.idx):\n",
    "            dW[word_id] += dout[i]\n",
    "        # 或者\n",
    "        # np.add.at(dW, self.idx, dout)\n",
    "        return None"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-06T12:54:10.763713100Z",
     "start_time": "2023-05-06T12:54:10.753712900Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
